#!/usr/bin/env python3

"""
Generate src/bin/termptydbl.{c,h} from unicode files
used with ucd.all.flat.xml from
https://www.unicode.org/Public/UCD/latest/ucdxml/ucd.all.flat.zip
"""

import argparse
from collections import namedtuple
import xml.etree.ElementTree as ET

# One run of code points sharing a width class. `start` and `end` are the
# first and last code point of the run (both inclusive).
URange = namedtuple('unicode_range', ['width', 'start', 'end'])


def _advance_range(ranges, current, width, cp):
    """Account for code point `cp` of class `width` and return the open range.

    When `width` matches the open range `current`, that range is stretched
    up to `cp`.  Otherwise `current` is closed (appended to `ranges`) and a
    new single-code-point range is opened.
    """
    if width == current.width:
        return current._replace(end=cp)
    ranges.append(current)
    return URange(width, cp, cp)


def get_ranges(xmlfile):
    """Build width ranges from a flat UCD XML file.

    Every ``char`` element of the repertoire that has a ``cp`` attribute is
    visited in document order.  Its ``ea`` attribute is reduced to one of
    three classes: 'Na' and 'H' become 'N', 'F' becomes 'W', anything else
    is kept as is.  A missing or empty ``ea`` is treated as 'N'.
    Consecutive elements of the same class are folded into one ``URange``.

    Args:
        xmlfile: file name or file object, handed to ``ET.parse``.

    Returns:
        A tuple ``(ranges_basic, ranges_emoji_double)``.  The first list
        uses the classes described above.  In the second list a code point
        is additionally forced to 'W' when it has ``Emoji="Y"`` and
        ``ExtPict="Y"`` and is not of class 'A', or when its ``blk``
        attribute is ``Misc_Pictographs``.  Both lists start with the range
        opened at code point 0 with class 'N'.
    """
    ns = "{http://www.unicode.org/ns/2003/ucd/1.0}"
    repertoire = ET.parse(xmlfile).getroot().find(ns + "repertoire")
    chars = repertoire.findall(ns + "char")

    ranges_basic = []
    ranges_emoji_double = []
    r_basic = URange('N', 0, 0)
    r_emoji_dbl = URange('N', 0, 0)
    for c in chars:
        # Check `cp` before anything else so that skipped elements can never
        # abort the run.
        # NOTE(review): elements without `cp` (presumably the
        # first-cp/last-cp range entries of the UCD XML format) contribute
        # nothing to the result -- confirm this is intended.
        cp = c.get('cp')
        if not cp:
            continue
        cp = int(cp, 16)

        ea = c.get('ea') or 'N'
        if ea in ('Na', 'H'):
            ea = 'N'
        elif ea == 'F':
            # Plain equality: `ea in ('F')` is a substring test on 'F'.
            ea = 'W'

        # basic
        r_basic = _advance_range(ranges_basic, r_basic, ea, cp)

        # emoji as wide
        is_emoji = c.get('Emoji') == 'Y' and c.get('ExtPict') == 'Y'
        if is_emoji and ea != 'A':
            ea = 'W'
        elif c.get('blk') == 'Misc_Pictographs':
            ea = 'W'
        r_emoji_dbl = _advance_range(ranges_emoji_double, r_emoji_dbl, ea, cp)

    # Close the ranges still open once the repertoire is exhausted.
    ranges_basic.append(r_basic)
    ranges_emoji_double.append(r_emoji_dbl)

    return (ranges_basic, ranges_emoji_double)


def merge_ranges(ranges, is_same_width):
    """Coalesce consecutive ranges that belong to the same width group.

    Args:
        ranges: sequence of range tuples providing ``end`` and ``_replace``
            (``URange`` values in this file), in merge order.
        is_same_width: callable invoked as ``is_same_width(candidate,
            current)``; a true result means `candidate` is absorbed into
            the range currently being built.

    Returns:
        A new list.  Each merged range keeps the width and start of the
        first range of its run and takes the end of the last one; the space
        between two absorbed ranges is not checked.  An empty input yields
        an empty list.
    """
    if not ranges:
        return []
    merged = []
    current = ranges[0]
    # Start at the second element: the first one is already `current`, so
    # comparing it with itself could only emit it twice.
    for candidate in ranges[1:]:
        if is_same_width(candidate, current):
            current = current._replace(end=candidate.end)
        else:
            merged.append(current)
            current = candidate
    merged.append(current)
    return merged

def skip_ranges(ranges, width_skipped):
    """Return a new list holding the ranges whose width is not skipped.

    Args:
        ranges: iterable of objects with a ``width`` attribute.
        width_skipped: container of width values to drop.
    """
    return [kept for kept in ranges if kept.width not in width_skipped]

def gen_header(mininum_codepoint, file_header):
    """Write the generated C header to `file_header`.

    The header declares the two lookup functions and defines the inline
    ``_termpty_is_dblwidth_get``, which returns ``EINA_FALSE`` straight away
    for any code point up to and including `mininum_codepoint`.

    Args:
        mininum_codepoint: integer, printed in upper-case hexadecimal as the
            upper bound of the early-return test.
        file_header: writable text file object.
    """
    declarations = (
"""/* XXX: Code generated by tool unicode_dbl_width.py */
#ifndef TERMINOLOGY_TERMPTY_DBL_H_
#define TERMINOLOGY_TERMPTY_DBL_H_ 1

Eina_Bool _termpty_is_wide(const Eina_Unicode g, Eina_Bool emoji_dbl_width);
Eina_Bool _termpty_is_ambigous_wide(const Eina_Unicode g, Eina_Bool emoji_dbl_width);

static inline Eina_Bool
_termpty_is_dblwidth_get(const Termpty *ty, const Eina_Unicode g)
{
   /* optimize for latin1 non-ambiguous */
""")
    early_test = f"   if (g <= 0x{mininum_codepoint:X})"
    dispatch = (
"""
     return EINA_FALSE;
   if (!ty->termstate.cjk_ambiguous_wide)
     return _termpty_is_wide(g, ty->config->emoji_dbl_width);
   else
     return _termpty_is_ambigous_wide(g, ty->config->emoji_dbl_width);
}

#endif
""")
    for piece in (declarations, early_test, dispatch):
        file_header.write(piece)

def gen_ambigous(ranges_basic, ranges_emoji_double, file_source):
    """Write the C function ``_termpty_is_ambigous_wide`` to `file_source`.

    The generated function holds two ``switch`` statements: the one used
    when ``emoji_dbl_width`` is set is built from `ranges_emoji_double`, the
    other from `ranges_basic`.  In both, ranges of width 'A' and 'W' are
    merged together and listed as cases returning ``EINA_TRUE``; ranges of
    width 'N' produce no case.

    Args:
        ranges_basic: list of ``URange`` values.
        ranges_emoji_double: list of ``URange`` values.
        file_source: writable text file object.
    """
    def same_group(r1, r2):
        # 'A' and 'W' count as one group here, 'N' is a group of its own.
        if r1.width != 'N':
            return r2.width in ('A', 'W')
        return r2.width == 'N'

    def emit_cases(ranges):
        # The first range is left out before merging.
        kept = skip_ranges(merge_ranges(ranges[1:], same_group), ('N',))
        last_idx = len(kept) - 1
        for idx, rng in enumerate(kept):
            if rng.width == 'N':
                continue
            # Every case label but the last one falls through to the next.
            suffix = "" if idx == last_idx else " EINA_FALLTHROUGH;"
            if rng.start != rng.end:
                file_source.write(f"           case 0x{rng.start:X} ... 0x{rng.end:X}:{suffix}\n")
            else:
                file_source.write(f"           case 0x{rng.start:X}:{suffix}\n")

    opening = (
"""
__attribute__((const))
Eina_Bool
_termpty_is_ambigous_wide(Eina_Unicode g, Eina_Bool emoji_dbl_width)
{
   if (emoji_dbl_width)
     {
        switch (g)
          {
""")
    between = (
"""             return EINA_TRUE;
         }
     }
   else
     {
        switch (g)
          {
""")
    closing = (
"""             return EINA_TRUE;
          }
     }
   return EINA_FALSE;
}
""")

    file_source.write(opening)
    emit_cases(ranges_emoji_double)
    file_source.write(between)
    emit_cases(ranges_basic)
    file_source.write(closing)


def gen_wide(ranges_basic, ranges_emoji_double, file_source):
    """Write the C function ``_termpty_is_wide`` to `file_source`.

    The generated function holds two ``switch`` statements: the one used
    when ``emoji_dbl_width`` is set is built from `ranges_emoji_double`, the
    other from `ranges_basic`.  In both, only ranges of width 'W' are listed
    as cases returning ``EINA_TRUE``; ranges of width 'N' or 'A' are merged
    together and produce no case.

    Args:
        ranges_basic: list of ``URange`` values.
        ranges_emoji_double: list of ``URange`` values.
        file_source: writable text file object.
    """
    not_wide = ('N', 'A')

    def same_group(r1, r2):
        # 'N' and 'A' count as one group here, 'W' is a group of its own.
        if r1.width not in not_wide:
            return r2.width == 'W'
        return r2.width in not_wide

    def emit_cases(ranges):
        # The first range is left out before merging.
        kept = skip_ranges(merge_ranges(ranges[1:], same_group), not_wide)
        last_idx = len(kept) - 1
        for idx, rng in enumerate(kept):
            if rng.width in not_wide:
                continue
            # Every case label but the last one falls through to the next.
            suffix = "" if idx == last_idx else " EINA_FALLTHROUGH;"
            if rng.start != rng.end:
                file_source.write(f"        case 0x{rng.start:X} ... 0x{rng.end:X}:{suffix}\n")
            else:
                file_source.write(f"        case 0x{rng.start:X}:{suffix}\n")

    opening = (
"""
__attribute__((const))
Eina_Bool
_termpty_is_wide(Eina_Unicode g, Eina_Bool emoji_dbl_width)
{
   if (emoji_dbl_width)
     {
        switch (g)
          {
""")
    between = (
"""             return EINA_TRUE;
          }
     }
   else
     {
        switch (g)
          {
""")
    closing = (
"""             return EINA_TRUE;
          }
     }
   return EINA_FALSE;
}
""")

    file_source.write(opening)
    emit_cases(ranges_emoji_double)
    file_source.write(between)
    emit_cases(ranges_basic)
    file_source.write(closing)


def gen_c(ranges_basic, ranges_emoji_double, file_header, file_source):
    """Write the generated header and C source.

    The header is written first; its early-return bound is the smaller of
    the two first ranges' ends.  The source then receives its include
    preamble followed by the output of ``gen_ambigous`` and ``gen_wide``.

    Args:
        ranges_basic: list of ``URange`` values, must not be empty.
        ranges_emoji_double: list of ``URange`` values, must not be empty.
        file_header: writable text file object for the header.
        file_source: writable text file object for the C source.
    """
    first_ends = (ranges_basic[0].end, ranges_emoji_double[0].end)
    gen_header(min(first_ends), file_header)

    preamble = (
"""/* XXX: Code generated by tool unicode_dbl_width.py */
#include "private.h"

#include <Elementary.h>
#include "termpty.h"
#include "termptydbl.h"
""")
    file_source.write(preamble)
    for emit in (gen_ambigous, gen_wide):
        emit(ranges_basic, ranges_emoji_double, file_source)

# Command line: <xml input> <header output> <source output>, all three
# opened by argparse itself.
parser = argparse.ArgumentParser(
    description='Generate code handling different widths of unicode codepoints.')
for _name, _mode in (('xml', 'r'), ('header', 'w'), ('source', 'w')):
    parser.add_argument(_name, type=argparse.FileType(_mode))

args = parser.parse_args()

ranges_basic, ranges_emoji_double = get_ranges(args.xml)
gen_c(ranges_basic, ranges_emoji_double, args.header, args.source)
